--- title: Training loop keywords: fastai sidebar: home_sidebar nb_path: "nbs/14_train_ae.ipynb" ---
{% raw %}
{% endraw %} {% raw %}
# Reload edited project modules automatically so notebook cells pick up
# source changes without restarting the kernel.
%load_ext autoreload
%autoreload 2
{% endraw %} {% raw %}
{% endraw %} {% raw %}
# Make only GPU index 1 visible to CUDA for this notebook session.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})
{% endraw %} {% raw %}

train_ae[source]

train_ae(model, dl, num_iter, optim_net, optim_psf, sched_net, sched_psf, min_int, microscope, log_interval, save_dir, log_dir, psf=None, bl_loss_scale=0.01, p_quantile=0, grad_clip=0.01, eval_dict=None)

{% endraw %} {% raw %}
{% endraw %} {% raw %}
# Load the experiment configuration for sample N2_352_2 and build the
# PSF, noise model and microscope it describes.
cfg = OmegaConf.load('../config/experiment/N2_352_2.yaml')
# cfg =OmegaConf.load(default_conf)
psf, noise, micro = load_psf_micro_psf_noise(cfg)
{% endraw %} {% raw %}
# Load the raw 3D stack and assemble the cropping pipeline:
# a background estimator, an ROI mask (pooled percentile threshold), and a
# random cubic crop restricted to that mask.
img_3d = load_tiff_image(cfg.data_path.image_path)[0]  # first channel/frame of the TIFF stack
estimate_backg = hydra.utils.instantiate(cfg.bg_estimation)
roi_mask       = get_roi_mask(img_3d, tuple(cfg.roi_mask.pool_size), percentile= cfg.roi_mask.percentile)
rand_crop      = RandomCrop3D((cfg.random_crop.crop_sz,cfg.random_crop.crop_sz,cfg.random_crop.crop_sz), roi_mask)
{% endraw %} {% raw %}
# Probability-map generator: linearly rescales raw image intensities from
# [data_min, data_max] into the configured [low, high] rate range.
probmap_generator = ScaleTensor(
    low=cfg.prob_generator.low,
    high=cfg.prob_generator.high,
    data_min=img_3d.min(),
    data_max=img_3d.max(),
)

# Total number of random crops drawn over the whole run.
total_crops = cfg.dataloader.num_iter * cfg.dataloader.bs

# Dataset of random 3D crops with per-crop rate maps and background estimates.
ds = DecodeDataset(
    path=cfg.data_path.image_path,
    dataset_tfms=[rand_crop],
    rate_transform=probmap_generator,
    bg_transform=estimate_backg,
    device='cuda:0',
    num_iter=total_crops,
)

decode_dl = DataLoader(ds, batch_size=2, num_workers=0)
{% endraw %} {% raw %}
# Derive forward-scaling constants from the raw stack and build the
# differentiable microscope forward model on the GPU.
# NOTE(review): `micro` from load_psf_micro_psf_noise above is shadowed here —
# presumably intentional; confirm the earlier instance is not needed.
inp_offset, inp_scale = get_forward_scaling(img_3d)
micro = Microscope(parametric_psf=[psf], noise=noise, multipl=cfg.microscope.multipl).cuda()

psf  .to('cuda')
micro.to('cuda')
Microscope(
  (noise): sCMOS()
)
{% endraw %} {% raw %}
# Visualise the current PSF volume as projections along each of the three axes.
plot_3d_projections(psf.psf_volume[0])
array([<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>], dtype=object)
{% endraw %} {% raw %}
    if cfg.evaluation is not None:
        eval_dict = dict(cfg.evaluation)
        eval_dict['crop_sl'] = eval(eval_dict['crop_sl'],{'__builtins__': None},{'s_': np.s_})
        eval_dict['px_size'] = list(eval_dict['px_size'])
    else:
        eval_dict = None
    
save_dir = Path(cfg.output.save_dir)
save_dir.mkdir(exist_ok=True, parents=True)

OmegaConf.save(cfg, cfg.output.save_dir + '/train.yaml')
{% endraw %} {% raw %}
# Restore the supervised-phase artifacts: model weights, microscope state,
# optimizer state, and a step LR schedule (halve every 1000 steps).
model_sl = load_model_state(cfg, 'model_sl.pkl').cuda()
micro.load_state_dict(torch.load(Path(cfg.output.save_dir)/'microscope_sl.pkl'))
opt_sl = AdamW(model_sl.parameters(), lr=cfg.supervised.lr)
opt_sl.load_state_dict(torch.load(Path(cfg.output.save_dir)/'opt_sl.pkl'))
scheduler_sl = torch.optim.lr_scheduler.StepLR(opt_sl, step_size=1000, gamma=0.5)

# Second optimizer for the autoencoder phase over the joint parameter set
# (microscope + PSF + network), with its own matching LR schedule.
ae_param = list(micro.parameters())  + list(psf.parameters()) + list(model_sl.parameters())
opt_ae  = AdamW(ae_param, lr=1e-4)
scheduler_ae = torch.optim.lr_scheduler.StepLR(opt_ae, step_size=1000, gamma=0.5)
{% endraw %} {% raw %}
# Load the evaluation image and its reference localisation dataframe
# (presumably ground truth for the plots below — confirm against eval_dict keys).
gt_img, gt_df = load_from_eval_dict(eval_dict)
{% endraw %} {% raw %}
# Pre-training sanity check: run the supervised model on the evaluation image,
# render the predicted emitters back through the microscope, and plot the
# prediction against the reference localisations.
with torch.no_grad():
    res_gt = model_sl(gt_img[None].cuda())  # add batch dim, move to GPU
    # Convert network output to emitter locations/offsets/intensities above a 0.1 threshold.
    locs_ae, x_os_ae, y_os_ae, z_os_ae, ints_ae, output_shape_ae = model_output_to_micro_input(res_gt, threshold=0.1)
    ae_img = micro(locs_ae, x_os_ae, y_os_ae, z_os_ae, ints_ae, output_shape_ae)
    pred_gt_df = model_output_to_df(res_gt, 0.1, px_size=eval_dict['px_size'])
    free_mem()  # release cached GPU memory before plotting

    # Overlay reconstruction (microscope render + predicted background) on the plot.
    gt_fig = gt_plot(gt_img, pred_gt_df, gt_df, eval_dict['px_size'],ae_img[0]+res_gt['background'][0])
    plt.show()
{% endraw %} {% raw %}
# Launch the autoencoder training phase with the two optimizers/schedules set
# up above (the run shown here was stopped manually — see the
# KeyboardInterrupt traceback in the captured output).
train_ae(model=model_sl, 
         dl=decode_dl, 
         num_iter=cfg.autoencoder.num_iter,
         optim_net=opt_sl, 
         optim_psf=opt_ae, 
         min_int=cfg.pointprocess.min_int, 
         psf=psf,
         sched_net=scheduler_sl, 
         sched_psf=scheduler_ae, 
         microscope=micro, 
         log_interval=cfg.supervised.log_interval,  
         save_dir=cfg.output.save_dir,
         log_dir=cfg.output.log_dir,
         bl_loss_scale=cfg.supervised.bl_loss_scale,
         p_quantile=cfg.supervised.p_quantile,
         grad_clip=cfg.supervised.grad_clip,
         eval_dict=eval_dict)
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-15-fe3b170c87b6> in <module>
     15          p_quantile=cfg.supervised.p_quantile,
     16          grad_clip=cfg.supervised.grad_clip,
---> 17          eval_dict=eval_dict)

<ipython-input-5-2c2c0bffcea2> in train_ae(model, dl, num_iter, optim_net, optim_psf, sched_net, sched_psf, min_int, microscope, log_interval, save_dir, log_dir, psf, bl_loss_scale, p_quantile, grad_clip, eval_dict)
     45 
     46         if grad_clip > 0:
---> 47             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip, norm_type=2)
     48 
     49         optim_net.step()

~/anaconda3/envs/decode2_dev/lib/python3.7/site-packages/torch/nn/utils/clip_grad.py in clip_grad_norm_(parameters, max_norm, norm_type)
     36         total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
     37     clip_coef = max_norm / (total_norm + 1e-6)
---> 38     if clip_coef < 1:
     39         for p in parameters:
     40             p.grad.detach().mul_(clip_coef.to(p.grad.device))

KeyboardInterrupt: 
{% endraw %} {% raw %}
# Export this notebook's tagged cells back into the library modules (nbdev).
!nbdev_build_lib
{% endraw %}